
import numpy as np 
import torch 
import h5py 
import matplotlib.pyplot as plt 

if __name__ == "__main__":
    # Debug viewer: show each slice of a raw volume side by side with a model
    # output and the reference label, one matplotlib window per slice.

    # Open read-only and explicitly: older h5py versions defaulted to mode
    # 'a', which can create or modify files. Read each dataset fully into
    # memory, then close the files so no handle stays open while the
    # (blocking) plot windows are shown.
    with h5py.File("./output_dir/out_brain_data20t1.h5", "r") as out_file:
        out_img = out_file["label"][()]
    with h5py.File("./self_data_val/brain_data20t1.h5", "r") as img_file:
        img = img_file["raw"][()]
        label = img_file["label"][()]
    print(out_img.shape)

    # The output is indexed as out_img[0, i] while raw/label are indexed as
    # [i]. Presumably the output has an extra leading (batch/channel) axis.
    # NOTE(review): confirm this layout against the file writer.
    n_slices = min(out_img.shape[1], img.shape[0], label.shape[0])
    if n_slices != out_img.shape[1]:
        # Iterating over out_img.shape[1] alone would hit IndexError on a
        # mismatch, so clamp to the common slice count and say so.
        print(
            f"warning: slice count mismatch (output={out_img.shape[1]}, "
            f"raw={img.shape[0]}, label={label.shape[0]}); "
            f"showing first {n_slices}"
        )

    for i in range(n_slices):
        fig, (ax_raw, ax_out, ax_label) = plt.subplots(1, 3)
        ax_raw.imshow(img[i], cmap="gray")
        ax_raw.set_title("raw")
        ax_out.imshow(out_img[0, i], cmap="gray")
        ax_out.set_title("output")
        ax_label.imshow(label[i], cmap="gray")
        ax_label.set_title("label")
        fig.suptitle(f"slice {i}")
        # Blocks until the window is closed, then moves to the next slice.
        plt.show()